//    Copyright 2024 FP6-LLM authors
//
//    Licensed under the Apache License, Version 2.0 (the "License");
//    you may not use this file except in compliance with the License.
//    You may obtain a copy of the License at
//
//        http://www.apache.org/licenses/LICENSE-2.0
//
//    Unless required by applicable law or agreed to in writing, software
//    distributed under the License is distributed on an "AS IS" BASIS,
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//    See the License for the specific language governing permissions and
//    limitations under the License.
//
// This file is adapted from
// https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/fp6_linear.cu

#include "kernel_matmul.cuh"
#include "kernel_reduction.cuh"
#include "core/registration.h"

#include <stdio.h>
#include <assert.h>

namespace aphrodite {

// Launches one QUANT_GEMM_Kernel instantiation on `stream`.
//
// Launch layout (as computed below):
//   grid  = (ceil(N_Global / TILE_N), M_Global * Split_K / TILE_M, 1)
//   block = (WARP_SIZE * BLOCK_WARPS, 1, 1)
//   dynamic shared memory = max(B tile + A tile, C tile) bytes; the kernel is
//   opted in to that size with cudaFuncSetAttribute before launching.
// Precondition: M_Global * Split_K must be a multiple of TILE_M, otherwise the
// grid's y-dimension is truncated (fpx_linear_kernel enforces M % 256 == 0).
// Nothing is returned: attribute/launch failures are left in the runtime's
// last-error state for the caller to collect with cudaGetLastError().
template <typename TilingConfig, typename OutputDataType, int EXPONENT,
          int MANTISSA>
static void Kernel_Ex(cudaStream_t stream, const uint4* Weight,
                      const half* Scales, const half* B, OutputDataType* C,
                      const size_t M_Global, const size_t N_Global,
                      const size_t K_Global, int Split_K) {
#ifdef DEBUG_MODE
  printf("\n");
  printf("Launcher.cu->Kernel_Ex():\n");
  // size_t arguments must be printed with %zu; %d is undefined for them.
  printf("M: %zu, N: %zu, K: %zu, SplitK: %d\n", M_Global, N_Global, K_Global,
         Split_K);
  printf("TILE_M: %d, TILE_K: %d, TILE_N: %d\n", TilingConfig::TILE_M,
         TilingConfig::TILE_K, TilingConfig::TILE_N);
#endif
  // Function-local static: evaluated once per template instantiation.
  static size_t SHMEM_SZ =
      max(TilingConfig::SMEM_SIZE_B_TILE + SMEM_SIZE_PER_TB_A_TILE,
          TilingConfig::SMEM_SIZE_C_TILE);
  cudaFuncSetAttribute(
      QUANT_GEMM_Kernel<TilingConfig, OutputDataType, EXPONENT, MANTISSA>,
      cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ);
  size_t dimN = (N_Global - 1) / TilingConfig::TILE_N + 1;
  size_t dimM = M_Global * Split_K / TilingConfig::TILE_M;
  dim3 GridDim(dimN, dimM, 1);
  dim3 BlockDim(WARP_SIZE * TilingConfig::BLOCK_WARPS, 1, 1);

#ifdef DEBUG_MODE
  // dim3 members are unsigned int (%u); SHMEM_SZ is size_t (%zu).
  printf(
      "GridDim.x: %u, GridDim.y: %u, GridDim.z: %u, BlockDim.x: %u, "
      "BlockDim.y: %u, BlockDim.z: %u SHMEM_SZ: %zu\n",
      GridDim.x, GridDim.y, GridDim.z, BlockDim.x, BlockDim.y, BlockDim.z,
      SHMEM_SZ);
  printf("\n");
#endif
  QUANT_GEMM_Kernel<TilingConfig, OutputDataType, EXPONENT, MANTISSA>
      <<<GridDim, BlockDim, SHMEM_SZ, stream>>>(Weight, Scales, B, C, M_Global,
                                                N_Global, K_Global, Split_K);
}

// Picks a TilingConfig from the N (batch) dimension and launches Kernel_Ex.
// Work around to support more N shapes: N is bucketed so that
//   N <= 8  -> TilingConfig<4, 1, 1>
//   N <= 16 -> TilingConfig<4, 1, 2>
//   N <= 32 -> TilingConfig<4, 1, 4>
//   N >  32 -> TilingConfig<4, 1, 8>
// Requires N_Global > 0 (checked by fpx_linear_kernel).
template <typename OutputDataType, int EXPONENT, int MANTISSA>
static void launch_fpx_gemm_for_n(cudaStream_t stream, const uint4* Weight,
                                  const half* Scales, const half* B,
                                  OutputDataType* C, const size_t M_Global,
                                  const size_t N_Global, const size_t K_Global,
                                  int Split_K) {
  if (N_Global <= 8) {
    Kernel_Ex<TilingConfig<4, 1, 1>, OutputDataType, EXPONENT, MANTISSA>(
        stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K);
  } else if (N_Global <= 16) {
    Kernel_Ex<TilingConfig<4, 1, 2>, OutputDataType, EXPONENT, MANTISSA>(
        stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K);
  } else if (N_Global <= 32) {
    Kernel_Ex<TilingConfig<4, 1, 4>, OutputDataType, EXPONENT, MANTISSA>(
        stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K);
  } else {
    Kernel_Ex<TilingConfig<4, 1, 8>, OutputDataType, EXPONENT, MANTISSA>(
        stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K);
  }
}

// Host entry point for the FPx (1 sign + EXPONENT + MANTISSA bits) x FP16 GEMM.
//
// Parameters:
//   stream               stream every kernel is launched on.
//   Weight, Scales       packed FPx weights and half scales, forwarded to the
//                        GEMM kernel.
//   B                    half activations.
//   C                    half output. Written by the GEMM directly when
//                        Split_K == 1; otherwise produced by SplitK_Reduction
//                        from the workspace.
//   M_Global, N_Global, K_Global
//                        GEMM sizes. Requires M % 256 == 0, K % 64 == 0 and
//                        N > 0.
//   Reduction_Workspace  fp32 scratch of Split_K * M_Global * N_Global
//                        elements. Only used, and then required non-null,
//                        when Split_K > 1.
//   Split_K              number of slices along K; must be >= 1.
// Returns cudaErrorInvalidValue for unsupported arguments, otherwise the
// result of cudaGetLastError() after the launches. Errors raised during
// asynchronous execution only surface at a later synchronizing call.
template <int EXPONENT, int MANTISSA>
cudaError_t fpx_linear_kernel(
    cudaStream_t stream, const uint4* Weight, const half* Scales, const half* B,
    half* C, const size_t M_Global, const size_t N_Global,
    const size_t K_Global,
    float* Reduction_Workspace,  // Reduction_Workspace_Size = Split_K *
                                 // M_Global * N_Global * sizeof(fp32)
    int Split_K) {
  // Debug builds trap on violated preconditions...
  assert(M_Global % 256 == 0);
  assert(K_Global % 64 == 0);
  assert(N_Global > 0);
  // ...and release builds (NDEBUG, where assert is a no-op) reject them rather
  // than launching a truncated grid or selecting a tiling for N == 0.
  if (M_Global % 256 != 0 || K_Global % 64 != 0 || N_Global == 0 ||
      Split_K < 1) {
    return cudaErrorInvalidValue;
  }

  if (Split_K == 1) {
    // No reduction step: the GEMM writes half results straight into C.
    launch_fpx_gemm_for_n<half, EXPONENT, MANTISSA>(
        stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K);
    return cudaGetLastError();
  }

  if (Reduction_Workspace == nullptr) {
    return cudaErrorInvalidValue;
  }
  // Split-K: the GEMM writes fp32 results into the workspace...
  launch_fpx_gemm_for_n<float, EXPONENT, MANTISSA>(
      stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
      K_Global, Split_K);
  cudaError_t status = cudaGetLastError();
  if (status != cudaSuccess) {
    // Do not reduce a workspace that the GEMM never filled.
    return status;
  }
  // ...and SplitK_Reduction then produces C from it.
  dim3 GridDim((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK, 1,
               1);
  dim3 BlockDim(WARP_SIZE, 1, 1);
  SplitK_Reduction<<<GridDim, BlockDim, 0, stream>>>(
      C, Reduction_Workspace, M_Global, N_Global, Split_K);
  return cudaGetLastError();
}
}  // namespace aphrodite

#include <torch/all.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>

// MODIFICATION NOTE: dtype of _weights is changed to uint8
/*
Computes FPx-FP16 GEMM (PyTorch interface).

[Mathematical Formula]
Standard definition of a linear layer: Out = In * trans(W), where In, Out, and
W are stored in row-major order. After an equivalent transformation:
trans(Out) = W * trans(In). Note that we do not perform a "transpose" at
runtime; instead, we interpret In/Out as column-major matrices when calling
our CUDA kernel.

[Inputs]
  _in_feats:  tensor of shape [B, IC];                  // half
  _weights:   uint8 tensor of shape [OC, IC // 8 * x];  // every x uint8 words
              hold 8 FPx weights, where x = 1 + exponent bits + mantissa bits
  _scales:    tensor of shape [OC];                     // half
  splitK:     number of slices the MatMul is split into along the K dimension
              for higher GPU utilization; default 1.

[Outputs]
  _out_feats: tensor of shape [B, OC];                  // half
*/
torch::Tensor fp_eXmY_linear_forward_cuda(int64_t EXPONENT, int64_t MANTISSA,
                                          torch::Tensor _in_feats,
                                          torch::Tensor _weights,
                                          torch::Tensor _scales,
                                          int64_t splitK = 1) {
  // Validates shapes, launches the FP(1+EXPONENT+MANTISSA) GEMM on the current
  // stream of the input's device and returns the half output [B, OC].
  // Invalid arguments, unsupported formats and launch failures raise
  // c10::Error through TORCH_CHECK.
  const int64_t NBITS = 1 + EXPONENT + MANTISSA;
  TORCH_CHECK(_in_feats.dim() == 2,
              "Expected in_feats to be a 2-D tensor, but received ",
              _in_feats.dim(), " dimensions");
  TORCH_CHECK(_weights.dim() == 2,
              "Expected weights to be a 2-D tensor, but received ",
              _weights.dim(), " dimensions");
  TORCH_CHECK(splitK >= 1, "Expected splitK >= 1, but received ", splitK);

  // The kernels read the raw buffers densely, so strided views would be
  // misread. Activations are made contiguous (a no-op when they already are);
  // weights and scales must already be contiguous.
  _in_feats = _in_feats.contiguous();
  TORCH_CHECK(_weights.is_contiguous(), "Expected weights to be contiguous");
  TORCH_CHECK(_scales.is_contiguous(), "Expected scales to be contiguous");

  const int64_t num_in_feats = _in_feats.size(0);
  const int64_t num_in_channels = _in_feats.size(1);
  const int64_t num_out_channels = _weights.size(0);
  TORCH_CHECK(num_in_channels % 64 == 0,
              "Expected in_features to be a multiple of 64, but received ",
              num_in_channels);
  // The launcher requires M % 256 == 0; its assert vanishes under NDEBUG.
  TORCH_CHECK(num_out_channels % 256 == 0,
              "Expected out_features to be a multiple of 256, but received ",
              num_out_channels);
  // Making sure the K dimension is matched: 8 weights per NBITS uint8 words.
  TORCH_CHECK(num_in_channels / 8 * NBITS == _weights.size(1),
              "Expected weights.size(1) == in_features / 8 * ", NBITS, " = ",
              num_in_channels / 8 * NBITS, ", but received ",
              _weights.size(1));
  TORCH_CHECK(_scales.numel() == num_out_channels,
              "Expected one scale per output channel (", num_out_channels,
              "), but received ", _scales.numel());

  const int64_t M = num_out_channels;
  const int64_t K = num_in_channels;
  const int64_t N = num_in_feats;

  // The current-stream lookup and the kernel launches are relative to the
  // current device, so switch to the input's device for this call.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));

  // Input Tensors
  auto weight = reinterpret_cast<const uint4*>(
      _weights.data_ptr<uint8_t>());  // weights is [OC, IC] but in FPx.
  auto in_feats = reinterpret_cast<const half*>(_in_feats.data_ptr<at::Half>());
  auto scales = reinterpret_cast<const half*>(_scales.data_ptr<at::Half>());

  // Output Tensors
  auto options = torch::TensorOptions()
                     .dtype(_in_feats.dtype())
                     .device(_in_feats.device());
  at::Tensor _out_feats =
      torch::empty({num_in_feats, num_out_channels}, options);
  // Empty batch or no output channels: nothing to compute, and the launcher
  // rejects N == 0.
  if (_out_feats.numel() == 0) {
    return _out_feats;
  }
  auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());

  // Split-K needs an fp32 scratch buffer of splitK * N * M elements
  // (Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32)).
  // With splitK == 1 the launcher writes the output directly and never reads
  // the workspace, so skip the allocation.
  at::Tensor _workspace;
  float* Reduction_Workspace = nullptr;
  if (splitK > 1) {
    _workspace = torch::empty(
        {splitK, num_in_feats, num_out_channels},
        torch::TensorOptions().dtype(torch::kFloat32).device(
            _in_feats.device()));
    Reduction_Workspace = _workspace.data_ptr<float>();
  }

  // MODIFICATION NOTE: use at::cuda::getCurrentCUDAStream() instead of default
  // stream (0) this fixes problem with CUDA graphs when used with
  // torch.compile()
  auto stream = at::cuda::getCurrentCUDAStream();
  const int split_k = static_cast<int>(splitK);

  /*
   The heuristic is weight_bit - exponent_bit - 1 = mantissa_bit
   */
  cudaError_t status = cudaSuccess;

  // FP2
  if (EXPONENT == 1 && MANTISSA == 0)
    status = aphrodite::fpx_linear_kernel<1, 0>(
        stream, weight, scales, in_feats, out_feats, M, N, K,
        Reduction_Workspace, split_k);

  // FP3
  else if (EXPONENT == 1 && MANTISSA == 1)
    status = aphrodite::fpx_linear_kernel<1, 1>(
        stream, weight, scales, in_feats, out_feats, M, N, K,
        Reduction_Workspace, split_k);
  else if (EXPONENT == 2 && MANTISSA == 0)
    status = aphrodite::fpx_linear_kernel<2, 0>(
        stream, weight, scales, in_feats, out_feats, M, N, K,
        Reduction_Workspace, split_k);

  // FP4
  else if (EXPONENT == 1 && MANTISSA == 2)
    status = aphrodite::fpx_linear_kernel<1, 2>(
        stream, weight, scales, in_feats, out_feats, M, N, K,
        Reduction_Workspace, split_k);
  else if (EXPONENT == 3 && MANTISSA == 0)
    status = aphrodite::fpx_linear_kernel<3, 0>(
        stream, weight, scales, in_feats, out_feats, M, N, K,
        Reduction_Workspace, split_k);
  else if (EXPONENT == 2 && MANTISSA == 1)
    status = aphrodite::fpx_linear_kernel<2, 1>(
        stream, weight, scales, in_feats, out_feats, M, N, K,
        Reduction_Workspace, split_k);

  // FP5
  else if (EXPONENT == 1 && MANTISSA == 3)
    status = aphrodite::fpx_linear_kernel<1, 3>(
        stream, weight, scales, in_feats, out_feats, M, N, K,
        Reduction_Workspace, split_k);
  else if (EXPONENT == 2 && MANTISSA == 2)
    status = aphrodite::fpx_linear_kernel<2, 2>(
        stream, weight, scales, in_feats, out_feats, M, N, K,
        Reduction_Workspace, split_k);
  else if (EXPONENT == 3 && MANTISSA == 1)
    status = aphrodite::fpx_linear_kernel<3, 1>(
        stream, weight, scales, in_feats, out_feats, M, N, K,
        Reduction_Workspace, split_k);
  else if (EXPONENT == 4 && MANTISSA == 0)
    status = aphrodite::fpx_linear_kernel<4, 0>(
        stream, weight, scales, in_feats, out_feats, M, N, K,
        Reduction_Workspace, split_k);

  // FP6
  else if (EXPONENT == 1 && MANTISSA == 4)
    status = aphrodite::fpx_linear_kernel<1, 4>(
        stream, weight, scales, in_feats, out_feats, M, N, K,
        Reduction_Workspace, split_k);
  else if (EXPONENT == 2 && MANTISSA == 3)
    status = aphrodite::fpx_linear_kernel<2, 3>(
        stream, weight, scales, in_feats, out_feats, M, N, K,
        Reduction_Workspace, split_k);
  else if (EXPONENT == 3 && MANTISSA == 2)
    status = aphrodite::fpx_linear_kernel<3, 2>(
        stream, weight, scales, in_feats, out_feats, M, N, K,
        Reduction_Workspace, split_k);
  else if (EXPONENT == 4 && MANTISSA == 1)
    status = aphrodite::fpx_linear_kernel<4, 1>(
        stream, weight, scales, in_feats, out_feats, M, N, K,
        Reduction_Workspace, split_k);
  else if (EXPONENT == 5 && MANTISSA == 0)
    status = aphrodite::fpx_linear_kernel<5, 0>(
        stream, weight, scales, in_feats, out_feats, M, N, K,
        Reduction_Workspace, split_k);

  // FP7
  else if (EXPONENT == 1 && MANTISSA == 5)
    status = aphrodite::fpx_linear_kernel<1, 5>(
        stream, weight, scales, in_feats, out_feats, M, N, K,
        Reduction_Workspace, split_k);
  else if (EXPONENT == 2 && MANTISSA == 4)
    status = aphrodite::fpx_linear_kernel<2, 4>(
        stream, weight, scales, in_feats, out_feats, M, N, K,
        Reduction_Workspace, split_k);
  else if (EXPONENT == 3 && MANTISSA == 3)
    status = aphrodite::fpx_linear_kernel<3, 3>(
        stream, weight, scales, in_feats, out_feats, M, N, K,
        Reduction_Workspace, split_k);
  else if (EXPONENT == 4 && MANTISSA == 2)
    status = aphrodite::fpx_linear_kernel<4, 2>(
        stream, weight, scales, in_feats, out_feats, M, N, K,
        Reduction_Workspace, split_k);
  else if (EXPONENT == 5 && MANTISSA == 1)
    status = aphrodite::fpx_linear_kernel<5, 1>(
        stream, weight, scales, in_feats, out_feats, M, N, K,
        Reduction_Workspace, split_k);

  else
    TORCH_CHECK(false, "FP", NBITS, " E", EXPONENT, "M", MANTISSA,
                " is not supported.");

  // The launcher reports invalid shapes and launch failures through its
  // return value; without this check the caller would receive the
  // uninitialized output of torch::empty.
  TORCH_CHECK(status == cudaSuccess, "FP", NBITS, " E", EXPONENT, "M",
              MANTISSA, " linear kernel failed: ", cudaGetErrorString(status));

  return _out_feats;
}

// Bind the CUDA-dispatch implementation of "fp_eXmY_linear_forward_cuda" for
// this extension's library (the op schema is presumably declared in the
// extension's TORCH_LIBRARY block elsewhere — confirm there).
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, ops) {
  ops.impl(/*name=*/"fp_eXmY_linear_forward_cuda",
           /*raw_f=*/&fp_eXmY_linear_forward_cuda);
}
